{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "This is common code used by multiple code snippets\n",
    "in this chapter. We have factored it out. It will\n",
    "be presented only once here.\n",
    "\"\"\"\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_parameters(params, learning_rate):\n",
    "    \"\"\"\n",
    "    Update the current weight and bias values\n",
    "    from gradient values.\n",
    "    \"\"\"\n",
    "    # Don't track gradients while updating params\n",
    "    with torch.no_grad():\n",
    "        for i, p in enumerate(params):\n",
    "            params[i] = p - learning_rate * p.grad\n",
    "            \n",
    "    # Restore tracking of gradient for all params\n",
    "    for i in range(len(params)):\n",
    "        params[i].requires_grad = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_line(m, c, min_x=0, max_x=10,\n",
    "              color='magenta', label=None):\n",
    "    \"\"\"\n",
    "    Plots y = mx + c from interval (min_x to max_x)\n",
    "    \"\"\"\n",
    "    # linspace creates an array of equally spaced\n",
    "    # values between the specified min and max in \n",
    "    # specified number of steps.\n",
    "    x = np.linspace(min_x, max_x, 100)\n",
    "    y = m*x + c\n",
    "    \n",
    "    plt.plot(x, y, color=color, \n",
    "             label='y=%0.2fx+%0.2f'%(m, c)\\\n",
    "                 if not label else label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_parabola(w0, w1, w2,  min_x=0, max_x=10,\n",
    "                  color='magenta', label=None):\n",
    "    \"\"\"\n",
    "    Plots y = w0 + w1*x + w2*x^2 from interval\n",
    "    (min_x to max_x)\n",
    "    \"\"\"\n",
    "    x = np.linspace(min_x, max_x, 100)\n",
    "    y = w0 + w1*x +  w2* (x**2)\n",
    "    plt.plot(x, y, color=color, \n",
    "             label='y=%0.2f+ %0.2fx + %0.2fx^2'\n",
    "             %(w0, w1, w2) if not label else label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_subplot(pos, step,\n",
    "                 true_draw_func, true_draw_params,\n",
    "                 pred_draw_func, pred_draw_params):\n",
    "    \"\"\"\n",
    "    Plots the curves corresponding to a specified pair\n",
    "    of functions.\n",
    "    We use it to plot\n",
    "    (i) the true function (used to generate the observations\n",
    "        that we are trying to predict with a trained mode)\n",
    "    vis a vis\n",
    "    (ii) the model function (used to makes the predictions)\n",
    "         When the predictor is good, the two plots should\n",
    "         more or less coincide.\n",
    "    Thus this is used to visualize the goodness of the\n",
    "    current approximation.\n",
    "    \"\"\"\n",
    "    plt.subplot(2, 2, pos)\n",
    "    plt.title('Step %d'%(step))\n",
    "    true_draw_func(**true_draw_params)\n",
    "    pred_draw_func(**pred_draw_params)\n",
    "    plt.xlabel('x')\n",
    "    plt.ylabel('y')\n",
    "    plt.legend(loc='upper left')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
